"""
拿到salary_Data2.csv 构建线性模型
1. 数据准备（读取数据）
2. 整理输入集，输出集
3. 构建模型
4. 训练模型
5. 测试模型
6. 模型可视化
"""
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.linear_model as lm

# 1. 数据准备（读取数据）
data = pd.read_csv('Salary_Data2.csv')
# print(data)
# 绘制散点图
# plt.scatter(data['YearsExperience'],data['Salary'])
# plt.show()

# 2. 整理输入集，输出集
train_x = data.iloc[:,:-1]
train_y = data.iloc[:,-1]

# 3. 构建模型
model = lm.LinearRegression()
# 4. 训练模型
model.fit(train_x,train_y)
# 5. 测试模型
pred_train_y = model.predict(train_x)
# 6. 模型可视化
plt.plot(train_x,pred_train_y,c='orangered')
plt.scatter(train_x,train_y,c='dodgerblue',s=50)
plt.show()



